145d560d4874bbef905df415e02f0925fc325a63,src/edu/stanford/nlp/sentiment/SentimentCostAndGradient.java,SentimentCostAndGradient,calculate,#number[]#,51
Before Change
binaryTD = TwoDimensionalMap.treeMap();
// stands for Classification Derivatives
TwoDimensionalMap<String, String, SimpleMatrix> binaryCD;
binaryCD = TwoDimensionalMap.treeMap();
// word vector derivatives
Map<String, SimpleMatrix> wordVectorD = Generics.newTreeMap();
After Change
// binaryTD stands for Transform Derivatives (see the SentimentModel)
TwoDimensionalMap<String, String, SimpleMatrix> binaryTD = TwoDimensionalMap.treeMap();
// binaryCD stands for Classification Derivatives
TwoDimensionalMap<String, String, SimpleMatrix> binaryCD = TwoDimensionalMap.treeMap();
// unaryCD stands for Classification Derivatives
Map<String, SimpleMatrix> unaryCD = Generics.newTreeMap();
// word vector derivatives
Map<String, SimpleMatrix> wordVectorD = Generics.newTreeMap();
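// The loops below size each accumulator as an all-zero SimpleMatrix;
// backpropDerivativesAndError later accumulates every tree's contribution
// into these matrices over the whole batch.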
for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : model.binaryTransform) {
  int numRows = entry.getValue().numRows();
  int numCols = entry.getValue().numCols();
  binaryTD.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(numRows, numCols));
  // The derivative matrix has one row for each class.  The number
  // of columns in the derivative matrix is the same as the number
  // of rows in the original transform matrix.
  binaryCD.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(model.numClasses, numRows));
}
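// Worked example with hypothetical sizes (not from this commit): with numHid = 25,
// a binary transform matrix is 25 x 51 (numHid x (2 * numHid + 1), including the
// bias column), so for numClasses = 5 the classification derivative above is 5 x 25.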
for (Map.Entry<String, SimpleMatrix> entry : model.unaryClassification.entrySet()) {
  int numRows = entry.getValue().numRows();
  int numCols = entry.getValue().numCols();
  unaryCD.put(entry.getKey(), new SimpleMatrix(numRows, numCols));
}
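// Unlike the binary case, the unary classification derivative simply mirrors the
// stored matrix's own dimensions, so numRows and numCols are reused directly.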
for (Map.Entry<String, SimpleMatrix> entry : model.wordVectors.entrySet()) {
  int numRows = entry.getValue().numRows();
  int numCols = entry.getValue().numCols();
  wordVectorD.put(entry.getKey(), new SimpleMatrix(numRows, numCols));
}
// TODO: This part can easily be parallelized (see the sketch after this method)
List<Tree> forwardPropTrees = Generics.newArrayList();
for (Tree tree : trainingBatch) {
  Tree trainingTree = tree.deepCopy();
  // this will attach the error vectors and the node vectors
  // to each node in the tree
  forwardPropagateTree(trainingTree);
  forwardPropTrees.add(trainingTree);
}
// TODO: we may find a big speedup by separating the derivatives and then summing
double error = 0.0;
for (Tree tree : forwardPropTrees) {
  backpropDerivativesAndError(tree, binaryTD, binaryCD, unaryCD, wordVectorD);
  error += sumError(tree);
}
value = error;
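// paramsToVector concatenates every derivative matrix, in the same order as the
// model's parameters are stacked, into one flat gradient vector of length
// theta.length, so the optimizer sees parameters and gradient in matching layout.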
derivative = RNNUtils.paramsToVector(theta.length, binaryTD.valueIterator(), binaryCD.valueIterator(),
                                     unaryCD.values().iterator(), wordVectorD.values().iterator());
}
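// A sketch of the parallelization suggested by the TODO above, not part of this
// change: each tree is deep-copied before forward propagation, so the per-tree
// work shares no mutable state and the batch can be mapped with a parallel
// stream (assumes a Java 8+ stream API; the helper name is hypothetical).
private List<Tree> forwardPropInParallel(List<Tree> trainingBatch) {
  return trainingBatch.parallelStream()
      .map(Tree::deepCopy)
      .map(t -> { forwardPropagateTree(t); return t; })
      .collect(java.util.stream.Collectors.toList());
}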
private void backpropDerivativesAndError(Tree tree,